package toms748

import (
	"errors"
	"math"

	"gitee.com/extrame/math/tools"
)

// EpsTolerance is a relative tolerance used as the termination criterion of
// the root finder: the underlying float64 is the allowed relative width of the
// bracketing interval.
type EpsTolerance float64

// Magnitude limits used by the solver in place of the true float64 extremes.
//
// maxFloat is the overflow threshold in safe_div and the fallback quotient
// used by quadraticInterpolate; minFloat (scaled by 32) is the smallest
// difference between two function values that Toms748WithF still treats as
// distinct.
var (
	maxFloat = math.Pow(10, 20)
	minFloat = math.Pow(10, -20)
)

// Judge reports whether a and b are close enough to stop iterating, i.e.
// whether |a-b| does not exceed the tolerance times the smaller of |a| and |b|.
func (e EpsTolerance) Judge(a, b float64) bool {
	gap := math.Abs(a - b)
	smaller := math.Min(math.Abs(a), math.Abs(b))
	return gap <= float64(e)*smaller
}

// bracket evaluates f at the trial point c and shrinks the enclosing interval
// [a, b] around the sign change.
//
// Before evaluating, c is nudged so that it is not (nearly) coincident with
// either end point: if the interval is already tiny the midpoint is used,
// otherwise c is kept at least |a|*tol above a and |b|*tol below b, where
// tol is 2*tools.Epsilon() (presumably machine epsilon; confirm against the
// tools package).
//
// The results are returned in the order a, b, fa, fb, d, fd, where (d, fd) is
// the end point that was discarded from the bracket together with its
// function value. If f(c) is exactly zero, a is set to c, fa to 0, and d and
// fd are both 0; b and fb are returned unchanged.
func bracket(f func(float64) float64, a, b, c, fa, fb float64) (float64, float64, float64, float64, float64, float64) {
	margin := tools.Epsilon() * 2

	// Keep the trial point strictly inside the interval.
	switch {
	case (b - a) < 2*margin*a:
		c = a + (b-a)/2
	case c <= a+math.Abs(a)*margin:
		c = a + math.Abs(a)*margin
	case c >= b-math.Abs(b)*margin:
		c = b - math.Abs(b)*margin
	}

	fc := f(c)

	// An exact zero: report the root through a / fa.
	if fc == 0 {
		return c, b, 0, fb, 0, 0
	}

	// Sign change between a and c: the root lies in [a, c], so b is discarded.
	if fa < 0 && fc >= 0 || fa >= 0 && fc < 0 {
		return a, c, fa, fc, b, fb
	}

	// Otherwise the root lies in [c, b], so a is discarded.
	return c, b, fc, fb, a, fa
}

// safe_div returns num/denom unless the quotient would be too large in
// magnitude, in which case it returns the fallback r instead.
//
// The quotient is considered too large when |denom| < 1 and
// |denom*maxFloat| <= |num|.
func safe_div(num, denom, r float64) float64 {
	if math.Abs(denom) < 1 && math.Abs(denom*maxFloat) <= math.Abs(num) {
		return r
	}
	return num / denom
}

// secantInterpolate performs a standard secant step through (a, fa) and
// (b, fb) and returns the resulting abscissa.
//
// If the secant point falls on, outside, or within a few ulps of either end
// of [a, b], the midpoint (a+b)/2 is returned instead so that the caller
// always receives a usable interior point.
func secantInterpolate(a, b, fa, fb float64) float64 {
	// The guard band is five machine epsilons, relative to each end point.
	// It must be built from machine epsilon, not Euler's number: a factor of
	// math.E*5 (about 13.6) makes the check below fire for practically every
	// input, turning every secant step into a bisection.
	eps := math.Nextafter(1, 2) - 1
	tol := eps * 5

	c := a - (fa/(fb-fa))*(b-a)
	if c <= a+math.Abs(a)*tol || c >= b-math.Abs(b)*tol {
		return (a + b) / 2
	}
	return c
}

// quadraticInterpolate fits a quadratic through (a, fa), (b, fb) and (d, fd)
// and approximates its root inside [a, b] by taking count Newton steps on the
// quadratic.
//
// Point d must lie outside the interval [a, b]; it is the third best
// approximation to the root after a and b.
//
// The quadratic is not guaranteed to have a root inside [a, b]. Whenever the
// coefficients cannot be determined, or the Newton iteration ends on or
// outside the end points, the result of a secant step is returned instead.
func quadraticInterpolate(a, b, d, fa, fb, fd float64, count int) float64 {
	// Divided differences: first order over [a, b], then second order.
	slope := safe_div(fb-fa, b-a, maxFloat)
	curvature := safe_div(safe_div(fd-fb, d-b, maxFloat)-slope, d-a, 0)

	if curvature == 0 {
		// Coefficients could not be determined; take a secant step.
		return secantInterpolate(a, b, fa, fb)
	}

	// Start the Newton iteration from b, or from a when the leading
	// coefficient and fa share the same sign.
	c := b
	if (curvature < 0 && fa < 0) || (curvature >= 0 && fa >= 0) {
		c = a
	}

	for remaining := count; remaining > 0; remaining-- {
		c -= safe_div(fa+(slope+curvature*(c-b))*(c-a), slope+curvature*(2*c-a-b), 1+c-a)
	}

	if c <= a || c >= b {
		// The iteration left the bracket; take a secant step.
		return secantInterpolate(a, b, fa, fb)
	}
	return c
}

// cubicInterpolate uses inverse cubic interpolation of f(x) at the points
// a, b, d and e to obtain an approximate root of f(x).
//
// Points d and e lie outside the interval [a, b] and are the third and fourth
// best approximations to the root found so far.
//
// The interpolant is not guaranteed to produce a point inside [a, b]; when
// the result is on or outside the end points, three-step quadratic
// interpolation through a, b and d is used instead.
func cubicInterpolate(a, b, d, e, fa, fb, fd, fe float64) float64 {
	// Abscissa differences and function-value differences shared by several
	// entries of the interpolation table.
	ab, bd, de := a-b, b-d, d-e
	fbfa, fdfb, fdfa := fb-fa, fd-fb, fd-fa

	// First column of the table.
	q11 := de * fd / (fe - fd)
	q21 := bd * fb / fdfb
	q31 := ab * fa / fbfa
	d21 := bd * fd / fdfb
	d31 := ab * fb / fbfa

	// Second column.
	step := d31 - q21
	q22 := (d21 - q11) * fb / (fe - fb)
	q32 := step * fa / fdfa
	d32 := step * fd / fdfa

	// Third column and the interpolated abscissa.
	q33 := (d32 - q22) * fa / (fe - fa)
	c := q31 + q32 + q33 + a

	if c <= a || c >= b {
		// Out of bounds step, fall back to quadratic interpolation.
		return quadraticInterpolate(a, b, d, fa, fb, fd, 3)
	}
	return c
}

// Toms748 narrows the interval [a, b] around a root of f using the TOMS 748
// algorithm and returns the final bracket.
//
// f is evaluated at both a and b first; those two evaluations are charged
// against max_iter, and the remaining budget is handed to Toms748WithF. When
// max_iter is 2 or less there is no budget left for iterating, so a and b are
// returned unchanged with a nil error and f is never called.
//
// tol decides when the bracket is narrow enough to stop. Errors are those
// reported by Toms748WithF (end points out of order, or no sign change
// between f(a) and f(b)).
func Toms748(f func(float64) float64, a, b float64, tol EpsTolerance, max_iter int) (float64, float64, error) {
	if max_iter <= 2 {
		return a, b, nil
	}
	// max_iter is passed by value, so there is nothing to restore after the
	// call; the adjusted budget is simply passed along.
	return Toms748WithF(f, a, b, f(a), f(b), tol, max_iter-2)
}

// Toms748WithF narrows the interval [a, b] around a root of f using the
// TOMS 748 algorithm, given the already computed values fa = f(a) and
// fb = f(b), and returns the final bracket.
//
// Parameters:
//   - f: the function whose root is sought.
//   - a, b: end points of the initial bracket; a must be strictly less than b.
//   - fa, fb: f(a) and f(b); they must have opposite signs unless one is zero.
//   - tol: termination criterion, applied to the current bracket.
//   - max_iter: maximum number of additional evaluations of f. A value of
//     zero or less returns a and b unchanged without evaluating f.
//
// Returns the end points of the final bracket. When an exact zero of f is
// found both returned values are equal to that root. A non-nil error is
// returned when a >= b or when fa and fb do not bracket a root; in that case
// both returned values are a.
func Toms748WithF(f func(float64) float64, a, b float64, fa, fb float64, tol EpsTolerance, max_iter int) (float64, float64, error) {
	// mu is the required shrink factor per outer iteration; if the bracket
	// shrinks by less than this an extra bisection step is taken.
	const mu = 0.5

	// Sanity check: are we allowed to iterate at all?
	if max_iter <= 0 {
		return a, b, nil
	}
	if a >= b {
		return a, a, errors.New("parameters a and b out of order")
	}
	if tol.Judge(a, b) || fa == 0 || fb == 0 {
		// Already converged, or one end point is an exact root.
		if fa == 0 {
			b = a
		} else if fb == 0 {
			a = b
		}
		return a, b, nil
	}
	if (fa < 0 && fb < 0) || (fa >= 0 && fb >= 0) {
		return a, a, errors.New("parameters a and b do not bracket the root")
	}

	// count is the number of evaluations of f still allowed; every call to
	// bracket evaluates f exactly once.
	count := max_iter

	// (d, fd) and (e, fe) are the two most recently discarded end points.
	// e and fe start as placeholders and are overwritten before cubic
	// interpolation reads them.
	var d, fd float64
	e, fe := 1e+5, 1e+5

	// First step: a secant step.
	c := secantInterpolate(a, b, fa, fb)
	a, b, fa, fb, d, fd = bracket(f, a, b, c, fa, fb)
	count--

	// Second step: quadratic interpolation.
	if count > 0 && fa != 0 && !tol.Judge(a, b) {
		c = quadraticInterpolate(a, b, d, fa, fb, fd, 2)
		e, fe = d, fd
		a, b, fa, fb, d, fd = bracket(f, a, b, c, fa, fb)
		count--
	}

	for count > 0 && fa != 0 && !tol.Judge(a, b) {
		// Save the bracket so its shrinkage can be measured at the end.
		a0, b0 := a, b

		// From the third step on, cubic interpolation is used whenever the
		// four function values are distinct; otherwise fall back to a
		// quadratic step.
		if valuesNearlyCoincide(fa, fb, fd, fe) {
			c = quadraticInterpolate(a, b, d, fa, fb, fd, 2)
		} else {
			c = cubicInterpolate(a, b, d, e, fa, fb, fd, fe)
		}

		// Re-bracket and check for termination.
		e, fe = d, fd
		a, b, fa, fb, d, fd = bracket(f, a, b, c, fa, fb)
		count--
		if count == 0 || fa == 0 || tol.Judge(a, b) {
			break
		}

		// Another interpolated step.
		if valuesNearlyCoincide(fa, fb, fd, fe) {
			c = quadraticInterpolate(a, b, d, fa, fb, fd, 3)
		} else {
			c = cubicInterpolate(a, b, d, e, fa, fb, fd, fe)
		}

		// Bracket again and check termination; e and fe are carried over
		// unchanged on this step.
		a, b, fa, fb, d, fd = bracket(f, a, b, c, fa, fb)
		count--
		if count == 0 || fa == 0 || tol.Judge(a, b) {
			break
		}

		// Double-length secant step, taken from the end point whose function
		// value is smaller in magnitude.
		u, fu := b, fb
		if math.Abs(fa) < math.Abs(fb) {
			u, fu = a, fa
		}
		c = u - 2*(fu/(fb-fa))*(b-a)
		if math.Abs(c-u) > (b-a)/2 {
			// The step is too long; bisect instead.
			c = a + (b-a)/2
		}

		// Bracket again and check termination.
		e, fe = d, fd
		a, b, fa, fb, d, fd = bracket(f, a, b, c, fa, fb)
		count--
		if count == 0 || fa == 0 || tol.Judge(a, b) {
			break
		}

		// If the bracket shrank fast enough, start the next round;
		// otherwise take an additional bisection step.
		if (b - a) < mu*(b0-a0) {
			continue
		}
		e, fe = d, fd
		a, b, fa, fb, d, fd = bracket(f, a, b, a+(b-a)/2, fa, fb)
		count--
	}

	// Collapse the bracket onto an exact root if one was found.
	if fa == 0 {
		b = a
	} else if fb == 0 {
		a = b
	}
	return a, b, nil
}

// valuesNearlyCoincide reports whether any two of the four function values
// differ by less than 32*minFloat, in which case cubic interpolation through
// them cannot be used.
func valuesNearlyCoincide(fa, fb, fd, fe float64) bool {
	minDiff := minFloat * 32
	return math.Abs(fa-fb) < minDiff ||
		math.Abs(fa-fd) < minDiff ||
		math.Abs(fa-fe) < minDiff ||
		math.Abs(fb-fd) < minDiff ||
		math.Abs(fb-fe) < minDiff ||
		math.Abs(fd-fe) < minDiff
}
